from nets.gcn_net import GCNNet
from nets.gat_net import GATNet
from nets.gated_gcn_net import GatedGCNNet

"""
Model loading.
Any one of the three models can be selected for training:
GCN, GAT, GatedGCN
"""


def GCN(net_params):
    """Create a ``GCNNet`` model.

    Args:
        net_params: Configuration handed unchanged to the ``GCNNet``
            constructor.

    Returns:
        The newly constructed ``GCNNet`` instance.
    """
    model = GCNNet(net_params)
    return model


def GAT(net_params):
    """Create a ``GATNet`` model.

    Args:
        net_params: Configuration handed unchanged to the ``GATNet``
            constructor.

    Returns:
        The newly constructed ``GATNet`` instance.
    """
    model = GATNet(net_params)
    return model


def GatedGCN(net_params):
    """Create a ``GatedGCNNet`` model.

    Args:
        net_params: Configuration handed unchanged to the ``GatedGCNNet``
            constructor.

    Returns:
        The newly constructed ``GatedGCNNet`` instance.
    """
    model = GatedGCNNet(net_params)
    return model


# Select any one of the supported models by name.
def gnn_model(MODEL_NAME, net_params):
    """Build the model registered under ``MODEL_NAME``.

    Args:
        MODEL_NAME: Registered model name; one of ``'GCN'``, ``'GAT'`` or
            ``'GatedGCN'`` (matched exactly, case-sensitive).
        net_params: Configuration forwarded unchanged to the selected
            model factory.

    Returns:
        Whatever the selected factory returns for ``net_params``.

    Raises:
        KeyError: If ``MODEL_NAME`` is not a registered name. The message
            lists the valid choices.
    """
    models = {
        'GCN': GCN,
        'GAT': GAT,
        'GatedGCN': GatedGCN,
    }

    try:
        factory = models[MODEL_NAME]
    except KeyError:
        # Keep KeyError for backward compatibility with existing callers,
        # but replace the bare key echo with an actionable message.
        raise KeyError(
            "Unknown model name {!r}; expected one of {}".format(
                MODEL_NAME, sorted(models)
            )
        ) from None

    return factory(net_params)
